{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy inputs: an all-ones vector plus matching cos/sin values\n",
    "# (each pair satisfies cos[i]**2 + sin[i]**2 ~= 1).\n",
    "x = torch.tensor([1., 1., 1., 1.])\n",
    "cos = torch.tensor([0.5403, 0.6479, 0.7318, 0.7965])\n",
    "sin = torch.tensor([8.4147e-01, 7.6172e-01, 6.8156e-01, 6.0469e-01])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rotate_half(x):\n",
    "    \"\"\"Swap the two halves of the last axis, negating the half moved to the front.\n",
    "\n",
    "    The last dimension is split at ``x.shape[-1] // 2`` and the result is\n",
    "    ``concat(-second_part, first_part)`` along that same dimension.\n",
    "    \"\"\"\n",
    "    split = x.shape[-1] // 2\n",
    "    front, back = x[..., :split], x[..., split:]\n",
    "    return torch.cat([torch.neg(back), front], dim=-1)\n",
    "\n",
    "def apply(x, cos=None, sin=None):\n",
    "    \"\"\"Applies the rotation to the input.\n",
    "\n",
    "    Computes ``x * cos + rotate_half(x) * sin``.\n",
    "\n",
    "    Args:\n",
    "        x: Input tensor.\n",
    "        cos: Cosine values multiplied element-wise with ``x``. When omitted,\n",
    "            the notebook-level ``cos`` variable is used.\n",
    "        sin: Sine values multiplied element-wise with ``rotate_half(x)``. When\n",
    "            omitted, the notebook-level ``sin`` variable is used.\n",
    "\n",
    "    Returns:\n",
    "        The rotated tensor.\n",
    "    \"\"\"\n",
    "    # Passing cos/sin explicitly keeps the result independent of whatever the\n",
    "    # notebook-level names happen to hold when a cell is re-run.\n",
    "    if cos is None:\n",
    "        cos = globals()[\"cos\"]\n",
    "    if sin is None:\n",
    "        sin = globals()[\"sin\"]\n",
    "    return x * cos + rotate_half(x) * sin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.4012)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hand-check of the last element; compare with apply(x)[3].\n",
    "x[3] * cos[3] + rotate_half(x)[3] * sin[3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-1., -1.,  1.,  1.])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For the all-ones x this shows the sign pattern rotate_half produces.\n",
    "rotate_half(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.3012, -0.1138,  1.4134,  1.4012])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same formula as the hand-check, evaluated for all four elements at once.\n",
    "apply(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0000,  0.0000,  0.6816,  0.0000],\n",
      "        [ 0.0000,  0.0000,  0.0000,  0.6047],\n",
      "        [-0.8415,  0.0000,  0.0000,  0.0000],\n",
      "        [ 0.0000, -0.7617,  0.0000,  0.0000]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([-0.3012, -0.1138,  1.4134,  1.4012])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def hardcoded_cos_rot_sin(x, cos, sin):\n",
    "    \"\"\"Builds, by hand, the 4 x 4 matrix M for which a length-4 row vector\n",
    "    satisfies ``x @ M == x * cos + rotate_half(x) * sin``.\n",
    "\n",
    "    Args:\n",
    "        x: Unused; kept so existing calls keep working.\n",
    "        cos: Indexable holding scalar values; only cos[0] .. cos[3] are read.\n",
    "        sin: Indexable holding scalar values; only sin[0] .. sin[3] are read.\n",
    "\n",
    "    Returns:\n",
    "        The 4 x 4 tensor ``m_cos + m_rot_sin``, laid out as\n",
    "            [[ cos[0],  0,       sin[2],  0     ],\n",
    "             [ 0,       cos[1],  0,       sin[3]],\n",
    "             [-sin[0],  0,       cos[2],  0     ],\n",
    "             [ 0,      -sin[1],  0,       cos[3]]]\n",
    "    \"\"\"\n",
    "    # Element-wise x * cos is a product with a diagonal matrix.\n",
    "    m_cos = torch.tensor([\n",
    "        [cos[0], 0., 0., 0.],\n",
    "        [0., cos[1], 0., 0.],\n",
    "        [0., 0., cos[2], 0.],\n",
    "        [0., 0., 0., cos[3]]\n",
    "    ])\n",
    "    # rotate_half(x) * sin as a matrix: column j holds the sin[j] weight on\n",
    "    # the row of the input element that rotate_half moves into slot j.\n",
    "    m_rot_sin = torch.tensor([\n",
    "        [0., 0., sin[2], 0.],\n",
    "        [0., 0., 0., sin[3]],\n",
    "        [-sin[0], 0., 0., 0.],\n",
    "        [0., -sin[1], 0., 0.]\n",
    "    ])\n",
    "    # Printed so the cell output shows the off-diagonal layout.\n",
    "    print(m_rot_sin)\n",
    "    # Sum the two pieces instead of transcribing a third matrix by hand, so\n",
    "    # the result cannot drift out of sync with the matrices above.\n",
    "    return m_cos + m_rot_sin\n",
    "\n",
    "# Row vector times the matrix; should reproduce apply(x).\n",
    "x @ hardcoded_cos_rot_sin(x, cos, sin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.3012, -0.1138,  1.4134,  1.4012])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def general_cos_rot_sin(cos, sin):\n",
    "    \"\"\"Extends the hand-written 4 x 4 matrix to cos/sin vectors of any length.\n",
    "\n",
    "    Builds the d x d matrix M (d = len(sin)) for which a length-d row vector x\n",
    "    satisfies ``x @ M == x * cos + rotate_half(x) * sin``.\n",
    "\n",
    "    Args:\n",
    "        cos: Cosine values; expected to be a 1-D tensor of length d.\n",
    "        sin: Sine values; expected to be a 1-D tensor of length d.\n",
    "\n",
    "    Returns:\n",
    "        ``diag(cos)`` plus the signed, row-shifted ``diag(sin)``.\n",
    "    \"\"\"\n",
    "    d = len(sin)\n",
    "    half = d // 2\n",
    "    # rotate_half puts -x[half:] into the first d - half output slots and\n",
    "    # x[:half] into the rest, so those leading columns carry a minus sign.\n",
    "    signs = torch.ones_like(sin)\n",
    "    signs[: d - half] = -1\n",
    "    # Output slot j reads input element (j + half) % d, which is the signed\n",
    "    # diagonal shifted down by `half` rows. For even d this equals stacking\n",
    "    # the two row blocks of diag(sin); for odd d only this form agrees with\n",
    "    # rotate_half.\n",
    "    m_rot_sin = torch.roll(torch.diag(signs * sin), shifts=half, dims=0)\n",
    "    return torch.diag(cos) + m_rot_sin\n",
    "\n",
    "# Should match both apply(x) and the hard-coded matrix result.\n",
    "x @ general_cos_rot_sin(cos, sin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "def precompute_freqs(dim: int, end: int, theta: float = 10000.0):\n",
    "    \"\"\"\n",
    "    Build cosine and sine lookup tables indexed by position and frequency.\n",
    "\n",
    "    Args:\n",
    "        dim (int): Dimension used to space the frequencies; ``dim // 2`` frequencies are produced.\n",
    "        end (int): Number of positions; positions run from 0 to ``end - 1``.\n",
    "        theta (float, optional): Base of the frequency progression. Grok-1 uses 10000.\n",
    "\n",
    "    Returns:\n",
    "        Tuple[torch.Tensor, torch.Tensor]: ``(cos, sin)``, each of shape ``(end, dim // 2)``;\n",
    "        entry ``[p, k]`` is the cos / sin of ``p / theta ** (2 * k / dim)``.\n",
    "    \"\"\"\n",
    "    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim\n",
    "    inv_freq = 1.0 / (theta ** exponents)\n",
    "    positions = torch.arange(end)\n",
    "    # Outer product: one row per position, one column per frequency.\n",
    "    angles = torch.outer(positions, inv_freq).float()\n",
    "    return torch.cos(angles), torch.sin(angles)\n",
    "\n",
    "\n",
    "def freq_row_to_rotation_matrix(cos_row, sin_row):\n",
    "    \"\"\"\n",
    "    Transform one cos/sin row into the square matrix that implements\n",
    "    cos + rotate_half * sin.\n",
    "\n",
    "    For a row vector x of length d = len(sin_row), the returned d x d matrix M\n",
    "    satisfies ``x @ M == x * cos_row + rotate_half(x) * sin_row``.\n",
    "\n",
    "    Args:\n",
    "        cos_row: Cosine values; expected to be a 1-D tensor of length d.\n",
    "        sin_row: Sine values; expected to be a 1-D tensor of length d.\n",
    "\n",
    "    Returns:\n",
    "        ``diag(cos_row)`` plus the signed, row-shifted ``diag(sin_row)``.\n",
    "    \"\"\"\n",
    "\n",
    "    d = len(sin_row)\n",
    "    half = d // 2\n",
    "    # rotate_half puts -x[half:] into the first d - half output slots and\n",
    "    # x[:half] into the rest, so those leading columns carry a minus sign.\n",
    "    signs = torch.ones_like(sin_row)\n",
    "    signs[: d - half] = -1\n",
    "    # Output slot j reads input element (j + half) % d, which is the signed\n",
    "    # diagonal shifted down by `half` rows. For even d this equals stacking\n",
    "    # the two row blocks of diag(sin_row); for odd d only this form agrees\n",
    "    # with rotate_half.\n",
    "    m_rot_sin = torch.roll(torch.diag(signs * sin_row), shifts=half, dims=0)\n",
    "    return torch.diag(cos_row) + m_rot_sin\n",
    "\n",
    "\n",
    "def get_rotation_mat(dhead, end, duplicate_halves=False):\n",
    "    \"\"\"Builds one rotation matrix per position from the precomputed cos/sin tables.\n",
    "\n",
    "    Args:\n",
    "        dhead (int): Passed to ``precompute_freqs`` as ``dim``; the tables it\n",
    "            returns have ``dhead // 2`` columns.\n",
    "        end (int): Number of positions; one matrix is built per position.\n",
    "        duplicate_halves (bool, optional): When True, every cos/sin row is\n",
    "            concatenated with itself before its matrix is built, so each\n",
    "            matrix is ``2 * (dhead // 2)`` per side and can multiply a\n",
    "            full-width head vector. Defaults to False, which keeps the\n",
    "            ``(dhead // 2, dhead // 2)`` matrices.\n",
    "\n",
    "    Returns:\n",
    "        list of torch.Tensor: ``end`` square matrices, in position order.\n",
    "    \"\"\"\n",
    "    cos, sin = precompute_freqs(dhead, end)\n",
    "    if duplicate_halves:\n",
    "        # rotate_half pairs element j with element j + d // 2, so both halves\n",
    "        # must carry the same angle for the matrix to act on a full vector.\n",
    "        cos = torch.cat((cos, cos), dim=-1)\n",
    "        sin = torch.cat((sin, sin), dim=-1)\n",
    "    return [freq_row_to_rotation_matrix(c, s) for c, s in zip(cos, sin)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 4])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE: this rebinds the notebook-level cos/sin that apply() and the earlier\n",
    "# cells read, so re-running those cells afterwards uses this table instead.\n",
    "cos, sin = precompute_freqs(8, 100)\n",
    "\n",
    "cos.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One matrix per position: a list of 16384 tensors.\n",
    "r = get_rotation_mat(128, 16384)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 64])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each matrix is (dhead // 2) per side (64 for dhead = 128), because\n",
    "# precompute_freqs returns dim // 2 frequencies per position.\n",
    "r[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
